% Analysis switch. Nonzero: replicates and cell types are averaged in log10
% space (i.e. a geometric mean of TPM). Zero: raw TPM values are averaged.
useGeoMean = 1;

% Load the TPM table: one row per gene, one column per sample.
T = readtable('Aso_TPM.csv');

% First column of T that holds TPM data. Earlier columns are treated as
% annotation (presumably just gene_id - confirm against the CSV header).
indexFirstData = 2;

% isolate numeric TPM data
data = table2array(T(:,indexFirstData:end));

% merge biological replicates

% Strip the trailing replicate suffix ('_1', '_2', ...) from each sample
% name. A regular expression is used rather than dropping the last two
% characters so that replicate numbers of 10 or more are handled too, and
% so that a name without a numeric suffix is left intact instead of being
% truncated.
cellNames = T.Properties.VariableNames(indexFirstData:end);
cellNamesStripReplicates = regexprep(cellNames, '_\d+$', '');

% uniqueCellNames: cell types in order of first appearance.
% indices(k): which entry of uniqueCellNames sample column k belongs to.
[uniqueCellNames, ~, indices] = unique(cellNamesStripReplicates, 'stable');

% Locate the APL cell type; everything downstream assumes exactly one match.
APL_index = find(strcmp(uniqueCellNames,'MB_APL'));
if numel(APL_index) ~= 1
    error('Expected exactly one MB_APL cell type in Aso_TPM.csv, found %d.', numel(APL_index));
end

% get gene IDs
geneIDs = T.gene_id;

% get gene names
namesTable = readtable('gene_names.csv');

% Keep the first occurrence of each gene ID, preserving file order.
% ia(k) is the row of namesTable that supplies the k-th unique gene ID.
[uniqueGeneIDs, ia] = unique(namesTable.gene_ID, 'stable');

% Symbol and full name for each unique gene ID, taken from that first row.
uniqueGeneSymbols = namesTable.gene_symbol(ia);
uniqueGeneNames = namesTable.gene_fullname(ia);

% Verify that the gene IDs are in the same order in gene_names.csv as they
% are in Aso_TPM.csv. The length is compared first because strcmp on two
% cell arrays of different sizes raises an unrelated size error instead of
% the message below.
if numel(geneIDs) ~= numel(uniqueGeneIDs) || any(~strcmp(geneIDs, uniqueGeneIDs))
    error('Gene names in input files are not in the same order!');
end

% Work on a log10 scale. Zero TPM becomes -Inf, so clamp everything below
% -2 up to -2, i.e. impose a TPM floor of 0.01 (minimum value above 0 is
% 0.1, probably can't resolve anything below that).
logData = log10(data);
belowFloor = logData < -2;
logData(belowFloor) = -2;

% Choose which representation gets averaged: log values (geometric mean of
% TPM) or raw TPM (arithmetic mean).
if useGeoMean
    valuesToAverage = logData;
else
    valuesToAverage = data;
end

% Collapse biological replicates: one column per unique cell type, each the
% mean over all sample columns assigned to that cell type.
nCellTypes = length(uniqueCellNames);
avgData = zeros(size(logData,1), nCellTypes);
for k = 1:nCellTypes
    avgData(:,k) = mean(valuesToAverage(:, indices==k), 2);
end

% Compare APL against the average of every other cell type.
% nonAPL holds the per-cell-type averages with the APL column removed;
% allDataNotAPL is their row-wise mean (one value per gene).
nonAPLindicesUnique = setdiff(1:length(uniqueCellNames), APL_index);
nonAPL = avgData(:,nonAPLindicesUnique);
allDataNotAPL = mean(nonAPL,2);

% APLdiff is always a log10 difference (APL minus non-APL). When averaging
% was done on raw TPM, both operands are converted to log10 first.
aplAverage = avgData(:,APL_index);
if useGeoMean
    APLdiff = aplAverage - allDataNotAPL;
else
    APLdiff = log10(aplAverage) - log10(allDataNotAPL);
end

% Gene symbols of interest, grouped by class. Each entry is matched exactly
% (strcmp) against the gene_symbol column of gene_names.csv, and must occur
% there exactly once. Each symbol is listed only once per class so that no
% gene produces duplicated rows in the output tables.

KChannels = [
    "Sh";
    "Shal";
    "Shaw";
    "Shab";
    "Shawl";
    "KCNQ";
    "eag";
    "Elk";
    "sei";
    "SK";
    "slo";
    "Irk1";
    "Irk2";
    "Task6";
    "Task7";
    "Ork1";
    "sand";
    "CG42594";
    "galene";
    "CG42340";
    "CG1688";
    "CG10864";
    "CG34396";
    "Hk";
    "Irk3";
    "SLO2";
    "CG43155"
    ];

VGCCs = [
    "Ca-alpha1D";
    "Ca-alpha1T";
    "Ca-beta";
    "cac"];

NaChannels = [
    "NaCP60E"; %aka DSC1
    "para";
    "tipE"];

other = [
    "ppk";
    "ppk11";
    "ppk16";
    "ppk23";
    "ppk25";
    "ppk29";
    "ppk28";
    "rpk";
    "ppk12";
    "ppk17";
    "ppk3";
    "ppk14";
    "ppk30";
    "ppk22";
    "ppk5";
    "ppk26";
    "ppk20";
    "ppk27";
    "ppk18";
    "ppk24";
    "ppk6";
    "ppk19";
    "ppk10";
    "ppk15";
    "ppk31";
    "ppk9";
    "ppk7";
    "ppk13";
    "ppk8";
    "ppk21";
    "na";
    "Nach";
    "kcc";
    "NKCC";
    "fwe";
    "iav";
    "nan";
    "nompC";
    "Orai";
    "Itpr";
    "Pkd2";
    "RyR";
    "SERCA";
    "trp";
    "TrpA1";
    "Trpml";
    "trpl";
    ];

ClChannels = [
    "ClC-a";
    "HisCl1";
    "ort";
    "Rdl";
    "Lcch3";
    "GluClalpha"];

posCtrl = [
    "Gad1";
    "VGAT";
    ];

% Classes in processing order. The combined gene list is built by prepending
% each class in turn, so the last class listed here ends up first in the
% output.
classes = {
    posCtrl;
    other;
    ClChannels;
    KChannels;
    NaChannels;
    VGCCs;
    };

% Minimum mean expression in non-APL neurons for a gene to be kept. The
% value is on the same scale as allDataNotAPL: log10 TPM when useGeoMean is
% set (0 means TPM of 1), raw TPM otherwise.
if useGeoMean
    thresholdValue = 0;
else
    thresholdValue = 1;
end

% Convert the gene symbols to a string array once, rather than on every
% lookup inside the loops below.
geneSymbolStrings = string(uniqueGeneSymbols);

classIndices = cell(length(classes),1);
channels = strings(0,1);        % gene symbols of all retained genes
channelIndices = zeros(0,1);    % matching row numbers into the data

% build up a list of index numbers for all the desired genes
for j=1:length(classes)
    classIndices{j} = zeros(length(classes{j}),1);
    for i=1:length(classes{j})
        % Note: index is the index in uniqueGeneSymbols which derives from
        % gene_names.csv. Later we will use it to index into the data in
        % Aso_TPM.csv - this is OK because we checked earlier that the gene
        % IDs are in exactly the same order in the two files. If they are
        % not in the same order, you need to re-code this section to search
        % through the geneIDs to match.
        index = find(strcmp(geneSymbolStrings,classes{j}(i)));
        if length(index)~=1
            error('%s is duplicated or not present', classes{j}(i));
        end

        % exclude genes where mean expression in non-APL MB neurons is not
        % above thresholdValue (their slot keeps the placeholder 0)
        if allDataNotAPL(index)>thresholdValue
            classIndices{j}(i) = index;
        end
    end

    % drop the genes that did not pass the expression threshold
    isKept = classIndices{j}~=0;
    classes{j} = classes{j}(isKept);
    classIndices{j} = classIndices{j}(isKept);

    % sort the genes according to the order of APLdiff (ascending)
    [~,sortIndex] = sort(APLdiff(classIndices{j}));
    classIndices{j} = classIndices{j}(sortIndex);
    classes{j} = classes{j}(sortIndex);

    % Prepend this class, keeping names and row indices in lockstep so no
    % second lookup by name is needed.
    channels = [classes{j}; channels];
    channelIndices = [classIndices{j}; channelIndices];
end

% One row per retained gene, one column per cell type.
channelData = avgData(channelIndices,:);

% logTPMnormToNonAPL.csv:
% columns: cell types. the identities are stored in the cell array
% "uniqueCellNames" (the original note gives APL as #14; the actual column
% is APL_index)
% rows: gene names, which are stored in the array "channels"
% values: per-cell-type average minus the non-APL average for that gene
% Graph in Fig. 2 is made of all the non-APL neurons from this file, then
% individual APL biological replicates from APLbioreps.csv
%
% NOTE(review): when useGeoMean is 0, channelData and allDataNotAPL are raw
% TPM while logData is log10, so the two differences below mix scales -
% confirm before running in that mode.

% Non-APL average for each selected gene; subtracted column-wise below.
nonAPLbaseline = allDataNotAPL(channelIndices);
csvwrite('logTPMnormToNonAPL.csv', channelData - nonAPLbaseline);

% write APL individual biological replicates (one column per APL sample)
isAPLreplicate = indices==APL_index;
csvwrite('APLbioreps.csv', logData(channelIndices,isAPLreplicate) - nonAPLbaseline);

% Make Fig 2B - heat map of log TPM in the following order (reorder to put DPM and APL at the end):
% KC_g
% KC_abprime
% KC_ab
% PPL1_g1pedc
% PAM_g5
% PPL1_g2a'1
% PPL1_a3
% PPL1_a'2a2
% PAM_b'2a
% PAM_g4_g1g2
% PAM_g3
% PAM_a1
% PAM_b1_b2
% MBON_g1pedc_a_b
% MBON_a3
% MBON_a1
% MBON_b1_a
% MBON_a'3
% MBON_g5b'2a
% MB_DPM
% MB_APL

% Hard-coded column permutation of uniqueCellNames: columns 14 and 15 are
% moved to the end in the order 15, 14. This assumes the 21-column layout
% listed above - TODO confirm if the input file changes.
heatMapColumnOrder = [1:13 16:21 15 14];

% Transpose so that cell types run along rows and genes along columns.
heatMapData = channelData(:,heatMapColumnOrder)';

% [0 3] is passed as the range argument (presumably the displayed log TPM
% limits - verify against writeTiffFireprintWithRange).
writeTiffFireprintWithRange(heatMapData,'heatMap.tif',[0 3]);
